TP-invariant Training: bitwise-identical training across TP degrees and GPU architecture #2977
jinzex wants to merge 4 commits
Conversation
I don't understand - how can you get the bitwise identical results across different GPU architectures if you don't use some custom GEMM implementation? cuBLAS does not provide any bitwise guarantees even across versions of the library and here we are looking at the different GPU architectures...
Thanks @ptrendx. It's done by enabling BIK (Batch Invariant Kernels; GEMM runs with Triton): https://github.com/NVIDIA/Megatron-LM/blob/dev/megatron/core/transformer/custom_layers/batch_invariant_kernels.py

GPU-architecture-invariant training is helpful for de-risking new hardware for our customers (i.e., from Hopper -> Blackwell -> Rubin and future generations).
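For context, the property BIK provides can be probed directly: a batch-invariant GEMM produces the same bits regardless of how the work is batched, whereas cuBLAS may choose different tilings or split-K strategies. Below is a minimal sketch of such a check; `gemm_fn` and `is_batch_invariant` are hypothetical stand-ins, not TE or Megatron API.

```python
# Minimal sketch: probing whether a GEMM path is batch-invariant by comparing
# a single full-batch launch against row-at-a-time launches, bit for bit.
# `gemm_fn` is a hypothetical stand-in (e.g. a Triton-backed BIK GEMM or a
# cuBLAS-backed matmul); nothing here is TE or Megatron API.
import torch

def is_batch_invariant(gemm_fn, m=128, k=256, n=64, dtype=torch.bfloat16):
    a = torch.randn(m, k, dtype=dtype, device="cuda")
    b = torch.randn(k, n, dtype=dtype, device="cuda")
    out_full = gemm_fn(a, b)                                             # one launch over the full batch
    out_rows = torch.cat([gemm_fn(a[i : i + 1], b) for i in range(m)])   # per-row launches
    # A batch-invariant kernel fixes its reduction order, so both paths agree
    # exactly; a heuristic-driven library may differ in the last ULPs.
    return torch.equal(out_full, out_rows)
```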
Gated on NVTE_TP_INVARIANT_MODE=1 (default off; stock paths unchanged).

- module/linear.py: row-parallel FWD + BWD full GEMM matching TP=1 K-dim accumulation.
- module/layernorm_linear.py: column-parallel BWD dgrad full GEMM with gated deinterleave for SwiGLU FC1 (partition_stride > 1).

Companion Megatron-LM PR (gates this code path via env var): NVIDIA/Megatron-LM#4740.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Jinze Xue <jinzex@nvidia.com>
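The SwiGLU FC1 deinterleave mentioned in the commit message addresses the fact that, with partition_stride=2, each column-parallel rank holds an interleaved gate/up pair, so the gathered weight is rank-major rather than in the TP=1 [gate; up] order. A rough sketch of the rearrangement under that assumed layout follows; names, shapes, and the exact layout are illustrative, not the PR's actual code.

```python
# Hedged sketch: regrouping gathered column-parallel weight shards so that a
# full dgrad GEMM sees the same column ordering as the TP=1 model. Assumes the
# gathered tensor is rank-major with `partition_stride` interleaved slices per
# rank (e.g. SwiGLU gate/up); this layout is an assumption for illustration.
import torch

def deinterleave_gathered_weight(w_gathered, tp_size, partition_stride):
    # w_gathered: [out_features, in_features], concatenation of per-rank shards
    out_features, in_features = w_gathered.shape
    chunk = out_features // (tp_size * partition_stride)
    # [rank, slice, chunk, in] -> [slice, rank, chunk, in]: all gate rows first,
    # then all up rows, matching the unsharded (TP=1) weight layout.
    w = w_gathered.view(tp_size, partition_stride, chunk, in_features)
    return w.transpose(0, 1).reshape(out_features, in_features)
```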
- Extract row-parallel fwd / column-parallel dgrad GEMMs to
module/_tp_invariant.py (mirrors _common.py precedent); main-file
call sites become helper calls.
- Drop NVTE_FP32_TP_REDUCE (separable feature; ~15 lines removed).
- Add tests/pytorch/distributed/{test,run}_tp_invariant.py: 20 cases
covering Linear with/without NVTE_TP_INVARIANT_MODE × parallel_mode
× sp × tp_size, plus LayerNormLinear partition_stride=2 (SwiGLU FC1
deinterleave). Reuses TestDistributedLinearBase from
run_numerics_exact.py (extended with partition_stride support).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Jinze Xue <jinzex@nvidia.com>
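As a rough picture of what such a parameter sweep can look like, here is a hedged sketch of a test matrix over the same axes; the launcher filename is taken from the list above, but the CLI flags, the exact case count, and the comparison it performs are assumptions, not the PR's actual test code.

```python
# Hedged sketch of a distributed invariance test matrix; the CLI flags and the
# comparison performed by the launcher are assumptions, not the PR's real API.
import os
import subprocess
import pytest

@pytest.mark.parametrize("tp_invariant", [False, True])
@pytest.mark.parametrize("parallel_mode", ["row", "column"])
@pytest.mark.parametrize("sequence_parallel", [False, True])
@pytest.mark.parametrize("tp_size", [2, 4])
def test_linear_tp_invariant(tp_invariant, parallel_mode, sequence_parallel, tp_size):
    env = dict(os.environ, NVTE_TP_INVARIANT_MODE="1" if tp_invariant else "0")
    cmd = [
        "torchrun", f"--nproc_per_node={tp_size}", "run_tp_invariant.py",
        "--parallel-mode", parallel_mode,
    ]
    if sequence_parallel:
        cmd.append("--sequence-parallel")
    # The launcher is expected to compare sharded fwd/bwd outputs bitwise
    # against a TP=1 reference and exit non-zero on any mismatch.
    subprocess.run(cmd, env=env, check=True)
```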
Wait, so how does that work? The kernels that you pointed to are in Megatron, so what is the purpose of this PR in TE then? The feature should be self-contained, so if you want to enable this feature in TE, it should also include those kernels.
Importing ``general_gemm`` directly into ``_tp_invariant.py`` bypasses downstream monkey-patches that rebind it in caller modules: for example, Megatron-LM's batch-invariant kernels patch the ``general_gemm`` symbol inside ``module.linear`` and ``module.layernorm_linear`` (their hardcoded target list), but not our new ``module._tp_invariant``. Result: the helper silently called the unpatched cuBLAS path, producing different bits than the BIK Triton path used elsewhere.

Pass ``gemm_fn`` from the caller so the helper uses whichever ``general_gemm`` binding the caller's namespace holds.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Jinze Xue <jinzex@nvidia.com>
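The pitfall is ordinary Python name binding: a monkey-patch rebinds the symbol in the modules it targets, not the underlying function, so a helper module that imported ``general_gemm`` at module scope keeps the original. A self-contained toy illustration (all names are stand-ins for TE's layout, not real TE or Megatron code):

```python
# Toy illustration of the binding issue: patching the caller's module-level
# symbol does not affect a helper that imported the function directly.
from types import SimpleNamespace

def general_gemm(a, b):                 # stand-in for the cuBLAS-backed GEMM
    return "cublas"

# Stand-ins for module.linear and module._tp_invariant as namespaces.
linear = SimpleNamespace(general_gemm=general_gemm)
_tp_invariant = SimpleNamespace(general_gemm=general_gemm)  # frozen at "import"

def helper_direct(a, b):
    return _tp_invariant.general_gemm(a, b)    # binding captured at import time

def helper_injected(a, b, gemm_fn):
    return gemm_fn(a, b)                        # binding supplied by the caller

# Downstream monkey-patch, applied only to the caller's module (as BIK does).
linear.general_gemm = lambda a, b: "triton"

assert helper_direct(None, None) == "cublas"                         # unpatched path
assert helper_injected(None, None, linear.general_gemm) == "triton"  # follows the patch
```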
Description
Adds TP-invariant GEMM for bitwise-identical fwd/bwd training across TP degrees (TP=1/2/4/8) and GPU architectures (H100 ≡ B300), gated by NVTE_TP_INVARIANT_MODE=1 (default off; stock paths unchanged).

Companion Megatron-LM PR: NVIDIA/Megatron-LM#4740 (gates this code path via env var, provides E2E validation scripts).
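To make the "full GEMM matching TP=1 K-dim accumulation" idea concrete: with row parallelism, the standard path runs a partial-K GEMM per rank followed by an all-reduce, whose accumulation order differs from the unsharded model. One way to match TP=1 bit for bit is to assemble the full K dimension on each rank and run a single full-K GEMM. The sketch below shows that idea only; the gathering strategy and all names are illustrative, not the PR's actual implementation.

```python
# Hedged sketch of a TP-invariant row-parallel forward: gather the K shards of
# both operands and perform one full-K GEMM, reproducing the TP=1 reduction
# order (given a deterministic GEMM kernel) at the cost of extra communication
# and redundant compute. Function and argument names are illustrative.
import torch
import torch.distributed as dist

def row_parallel_forward_tp_invariant(x_local, w_local, tp_group):
    # x_local: [m, k_local]   activation shard on this rank
    # w_local: [k_local, n]   weight shard on this rank (row-parallel split on K)
    tp_size = dist.get_world_size(tp_group)

    # Gather the full K dimension of both operands onto every rank.
    x_parts = [torch.empty_like(x_local) for _ in range(tp_size)]
    w_parts = [torch.empty_like(w_local) for _ in range(tp_size)]
    dist.all_gather(x_parts, x_local, group=tp_group)
    dist.all_gather(w_parts, w_local, group=tp_group)
    x_full = torch.cat(x_parts, dim=1)     # [m, k]
    w_full = torch.cat(w_parts, dim=0)     # [k, n]

    # One full-K GEMM: the same reduction the TP=1 model performs.
    return x_full @ w_full
```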
Type of change
Changes
- module/linear.py: row-parallel FWD + BWD full GEMM matching TP=1 K-dim accumulation.
- module/layernorm_linear.py: column-parallel BWD dgrad full GEMM with gated deinterleave for SwiGLU FC1 (partition_stride > 1).

Checklist: